# Merging with AP3 model 
library(pacman)
p_load(
  here, data.table, fst, readxl, stringr, collapse, dplyr
)

# Goal: Calculate marginal damages for each power plant 
# 1) Load emissions rates for each plant (so2, co2, nox, pm2.5)
# 2) Merge plants to marginal damages per ton of pollutant (using fips and stack height)
# 3) Multiply emissions rates by marginal damages

# TODO: for plants with a missing stack height, impute the median stack height of plants with the same fuel type
# Calculate marginal damages per MWh for each power plant and write them to disk.
#
# Pipeline:
#   1) Load plant info (most recent year per plant) and plant emissions rates.
#   2) Compute each plant's share of its shaped fleet's 2019-2020 net generation.
#   3) Attach PM2.5 rates to the existing rates and append shaped-plant rates.
#   4) Merge with AP3 marginal damages per ton by fips, stack height, pollutant;
#      rows with pollutant == 'co2e' use SOCIAL_COST_OF_CARBON instead.
#   5) Aggregate shaped plants to the fleet level, then sum over pollutants.
#
# Args:
#   SOCIAL_COST_OF_CARBON: damage per ton applied to 'co2e' rows.
#
# Side effects (files written, relative to here()):
#   Data/electricity-generation/output/emissions-md-dt.fst   (long format)
#   Data/electricity-generation/output/damage-per-mwh-oge.csv (wide format)
#
# Returns: the value of the final print('done') call.
calc_marginal_damages = function(
    SOCIAL_COST_OF_CARBON = 51*0.7993 # 51 in 2022 dollars to 2014
  ){
# Loading Data ----------------------------------------------------------------
  print('Loading data')
  # Unit information table, taking most recent year for each plant
  oge_plant_info_dt =
    read.fst(
      path = here("Data/electricity-generation/plant-info-dt-oge.fst"),
      as.data.table = TRUE
    )
  oge_plant_info_unique_dt = 
    oge_plant_info_dt[, 
      .SD[which.max(year)], 
      by = .(plant_id_eia)
    ]
  # Now the emissions rates; the inner merge keeps only plants present in both
  # the emissions-rate table and the plant info table
  emissions_rate_dt = 
    merge(
      read.fst(
        here('Data/electricity-generation/plant-model-fit/emissions-rate-dt.fst'),
        as.data.table = TRUE
      ),
      oge_plant_info_unique_dt,
      by = 'plant_id_eia'
    )[,.(
      plant_id_eia, 
      fips, 
      pollutant, 
      # For now we only have low/medium damages (no 'tall' category yet).
      # NOTE: there is no default here, so a missing stack_height stays NA and
      # will not match any AP3 row in the merge further down (the shaped-plant
      # branch below defaults to 'low' instead).
      stack_height = fcase(
        stack_height < 250, 'low',
        stack_height >= 250, 'medium'
      ),
      output, 
      net_generation_mwh_split, 
      # Converting emissions rates to tons/mwh, so2+nox+co2 in lbs
      emissions_rate_tons = emissions_rate/2000
    )]
  # Reading in the average emissions rates (source of the pm2.5 rates and of
  # all rates for shaped plants)
  avg_emissions_dt = fread(
    here('Data/electricity-generation/output/emissions-rate-avg-oge.csv')
  )
# Calculating weights for shaped plants ---------------------------------------
  print('shaped plant weights')
  # First loading plant id's we have elec model for 
  modeled_plant_id_eia_dt =
    read.fst(
      path = here("Data/electricity-generation/plant-model-fit/elec-gen-fit-dt.fst"),
      as.data.table = TRUE
    )[,.(plant_id_eia)] |> unique()
  # Ids >= 900000 are treated as shaped (fleet) ids, everything else as regular
  shaped_plants = modeled_plant_id_eia_dt[plant_id_eia >= 900000]$plant_id_eia
  regular_plants = modeled_plant_id_eia_dt[plant_id_eia < 900000]$plant_id_eia
  # Loading monthly plant generation data 
  monthly_plant_dt = 
    rbind(
      fread(here(
        'Data/electricity-generation/open-grid-emissions/montly-plant-data/plant_data_2019.csv'
      ))[,year:=2019],
      fread(here(
        'Data/electricity-generation/open-grid-emissions/montly-plant-data/plant_data_2020.csv'
      ))[,year:=2020]
    )
  # Calculating % of fleet that each individual plant produces.
  # Only plants that belong to a modeled shaped fleet and are not modeled
  # individually are kept.
  shaped_plant_weights =
    merge(
      monthly_plant_dt, 
      oge_plant_info_dt,
      by = c('plant_id_eia','year')
    )[shaped_plant_id %in% shaped_plants
      &!(plant_id_eia %in% regular_plants),
      .(tot_net_gen = sum(net_generation_mwh)),
      keyby = .(plant_id_eia, shaped_plant_id)
    ][,wt_net_gen := tot_net_gen/sum(tot_net_gen),
      by = shaped_plant_id
    ][, # A fleet with zero total generation gives 0/0 = NaN; force those to 0
      wt_net_gen := ifelse(tot_net_gen == 0, 0, wt_net_gen)
    ]
  # Diagnostic: share of shaped fleets whose weights do not sum to ~1
  print(shaped_plant_weights[,
    .(tot = sum(wt_net_gen)), 
    by = shaped_plant_id
  ][,.(wt_not_one = mean(tot < 0.99999 | tot > 1.000001))])
  print('Some shaped plants have no production...')
  
# Adding PM2.5 to rest of pollutants ------------------------------------------
  # First for the CEMS plants: take one row per plant/output/split and attach
  # the average pm2.5 rate, replacing any pm25 rows already present
  emissions_rate_dt = 
    rbind(
      emissions_rate_dt[pollutant != 'pm25'],
      merge(
        emissions_rate_dt[,.(
          plant_id_eia,
          fips, stack_height, 
          output, net_generation_mwh_split
        )] |>unique(),
        avg_emissions_dt[,.(
          plant_id_eia,
          pollutant = 'pm25',
          emissions_rate_tons = pm25_tons_per_mwh
        )],
        by = c('plant_id_eia')
      ),
      use.names = TRUE
    )
  # Now putting together data for the shaped plants: one row per plant and
  # pollutant, built from every '*tons_per_mwh' column
  shaped_emissions_rate_dt = 
    melt(
      avg_emissions_dt[shaped == TRUE],
      id.var = 'plant_id_eia',
      measure = patterns('tons_per_mwh')
    ) |>
    merge(
      oge_plant_info_unique_dt,
      by = 'plant_id_eia'
    ) %>% 
    .[,.(
      plant_id_eia, 
      fips, 
      pollutant = str_remove(variable, '_tons_per_mwh'), 
      # Missing stack heights fall back to 'low' here
      stack_height = fcase(
        stack_height < 250, 'low',
        stack_height >= 250, 'medium',
        default = 'low'
      ),
      output = 'emissions_rate_low',
      net_generation_mwh_split = 1,
      emissions_rate_tons = value
    )]
  # Shaped plants have a single average rate, so the same rows are used for
  # both the 'low' and 'high' outputs (every row above is 'emissions_rate_low')
  shaped_emissions_rate_dt = 
    rbind(
      shaped_emissions_rate_dt,
      copy(shaped_emissions_rate_dt)[, output := 'emissions_rate_high']
    )
  # Adding to rest of data, skipping plants that already have rates
  emissions_rate_dt = 
    rbind(
      emissions_rate_dt,
      shaped_emissions_rate_dt[!(plant_id_eia %in% unique(emissions_rate_dt$plant_id_eia))]
    )
  
# Now for marginal damages ----------------------------------------------------
  print('damages')
  # County fips list; only the fips column is kept, zero-padded to 5 characters
  fips_dt = read_xlsx(
    path = here('Data/Marginal Damages (2011) from Holland, Mansur, Muller, Yates AER forthcoming.xlsx'),
    sheet = 'ground-level MD per ton'
    ) %>%  
    data.table() %>% 
    .[,.( # Miami-Dade changed their fips code
      fips = fcase(
        fips == '12025', '12086',
        fips != '12025', str_pad(fips, 5,'left','0')
      )
    )]
  # TODO: Figure out county list for tall stack heights
  # The AP3 files carry no fips column, so rows are matched to fips_dt purely
  # by position (cbind) -- this relies on both sources sharing the same order
  md_dt = 
    rbind(
      cbind(
        fips_dt,
        fread(here("Data/AP3/JointFolder/md_L_2014_CS.csv"))[,stack_height := 'low']
      ),
      cbind(
        fips_dt,
        fread(here("Data/AP3/JointFolder/md_M_2014_CS.csv"))[,stack_height := 'medium']
      )
    ) |> 
    setnames(new = c('fips','nh3','nox','pm25','so2','voc','stack_height')) |>
    melt(
      id.vars = c('fips','stack_height'),
      variable.name = 'pollutant',
      value.name = 'marginal_damage'
    )
  # Merging emissions data; left join so unmatched rows keep NA damages
  emissions_md_dt_all = 
    merge(
      emissions_rate_dt, 
      md_dt,
      by = c('fips','stack_height','pollutant'),
      all.x = TRUE
    )
  # Adding social cost of carbon 
  emissions_md_dt_all[,marginal_damage := fcase(
    pollutant == 'co2e', SOCIAL_COST_OF_CARBON,
    pollutant != 'co2e', marginal_damage
  )]
  # Aggregating shaped plants up to the fleet level 
  emissions_md_dt = 
    merge(
      emissions_md_dt_all, 
      shaped_plant_weights,
      by = c('plant_id_eia'),
      all.x = TRUE
    )[, # Setting weights = 1 if the plant is missing from shaped_plant_weights
      wt_net_gen := ifelse(is.na(wt_net_gen), 1, wt_net_gen)
    ][,.( # Calculating weighted marginal damages (weights only really apply for shaped)
        md_per_mwh = sum(emissions_rate_tons*marginal_damage*wt_net_gen),
        emissions_rate_tons = sum(emissions_rate_tons*wt_net_gen)),
      keyby = .(
        # Fleet members are relabelled with their shaped (fleet) id
        plant_id_eia = ifelse(
          is.na(shaped_plant_id), plant_id_eia, 
          shaped_plant_id),
        output, 
        net_generation_mwh_split,
        pollutant 
    )]
  # Saving for later
  write.fst(
    emissions_md_dt, 
    path = here('Data/electricity-generation/output/emissions-md-dt.fst')
  )
  # Casting into wide format: one column per output x pollutant
  emissions_md_dt_wide = 
    dcast(
      emissions_md_dt,
      plant_id_eia + net_generation_mwh_split ~ output + pollutant,
      value.var = 'md_per_mwh'
    )
  setnames(
    emissions_md_dt_wide,
    new = colnames(emissions_md_dt_wide) |> str_replace('emissions_rate','md_per_mwh')
  )
  # Aggregating over all of the pollutants (still long: one row per output).
  # na.rm = TRUE means pollutants with NA damages contribute zero to the total.
  tot_md_long_dt = emissions_md_dt[,.(
    damages_per_mwh = sum(md_per_mwh, na.rm =TRUE)),
    keyby = .(plant_id_eia, output, net_generation_mwh_split)
  ] 
  # net_generation_mwh_split is kept in the cast formula and the merge key so
  # every (plant, split) pair stays a unique row. Casting on plant_id_eia alone
  # would make dcast fall back to fun.aggregate = length (row counts instead of
  # damages) whenever a plant has more than one split value.
  tot_md_dt = 
    dcast(
      tot_md_long_dt,
      plant_id_eia + net_generation_mwh_split ~ output,
      value.var = 'damages_per_mwh'
    ) |>
    setnames(
      old = c('emissions_rate_high','emissions_rate_low'),
      new = c('md_per_mwh_high','md_per_mwh_low')
    ) |>
    merge(
      emissions_md_dt_wide,
      by = c('plant_id_eia','net_generation_mwh_split')
    )
  # Keep the column layout the output file has always had: id, totals, split,
  # then the per-pollutant columns
  setcolorder(
    tot_md_dt,
    c('plant_id_eia','md_per_mwh_high','md_per_mwh_low','net_generation_mwh_split')
  )
  # Saving the results
  write.csv(
    tot_md_dt, 
    here('Data/electricity-generation/output/damage-per-mwh-oge.csv')
  )
  print('done')
}




